package org.hepeng.workx.mybatis.mapper;

import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.reflect.FieldUtils;
import org.hepeng.workx.mybatis.util.WorkXMybatisEnvironment;
import org.hibernate.cfg.Environment;
import org.hibernate.dialect.Database;
import org.hibernate.dialect.Dialect;
import org.hibernate.sql.Delete;
import org.hibernate.sql.Insert;
import org.hibernate.sql.Select;
import org.hibernate.sql.SelectFragment;
import org.hibernate.sql.Update;

import java.lang.reflect.Field;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;

/**
 * @author he peng
 */
public class HibernateMappedStatementSupplier extends AbstractMappedStatementSupplier {

    /** Opening tag of the MyBatis {@code <trim>} that strips the trailing comma of a column/value list. */
    private static final String TRIM_OPEN = "<trim prefix=\"\" suffix=\"\" suffixOverrides=\",\">";

    /** Hibernate dialect resolved once from the configured WorkX MyBatis dialect name. */
    private final Dialect dialect;

    /**
     * Resolves the Hibernate {@link Dialect} to use for statement rendering.
     *
     * <p>The name returned by {@code WorkXMybatisEnvironment.getDialect(null)} is upper-cased and
     * mapped onto a {@link Database} constant; that database's latest dialect class is then
     * instantiated through {@link Dialect#getDialect(Properties)}.
     *
     * @throws NullPointerException     if no dialect name is configured
     * @throws IllegalArgumentException if the configured name is not a {@link Database} constant
     */
    public HibernateMappedStatementSupplier() {
        String dialectName = Objects.requireNonNull(WorkXMybatisEnvironment.getDialect(null),
                "No dialect configured for WorkX MyBatis");
        Database database;
        try {
            // Locale.ROOT: enum names must not depend on the JVM locale (e.g. Turkish dotless i).
            database = Database.valueOf(StringUtils.upperCase(dialectName, Locale.ROOT));
        } catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Unsupported dialect '" + dialectName + "'", e);
        }
        // Database.valueOf never returns null, so no null check is needed here.
        Class<? extends Dialect> dialectClass = database.latestDialect();
        Properties props = new Properties();
        props.setProperty(Environment.DIALECT, dialectClass.getName());
        this.dialect = Dialect.getDialect(props);
    }

    /**
     * Builds an INSERT that lists every mapped column, each bound to {@code #{field}}.
     *
     * <p>Identity columns are first registered through {@link Insert#addIdentityColumn(String)};
     * a later {@code addColumn} for the same column name replaces that value.
     */
    @Override
    protected String createAllColumnInsertSQL() {
        Insert insert = new Insert(dialect);
        insert.setTableName(tableMapMetaData.getTableName());

        for (String identityColumnName : tableMapMetaData.getIdentityColumnNames()) {
            insert.addIdentityColumn(identityColumnName);
        }

        for (Map.Entry<String, String> entry : tableMapMetaData.getColumnToFieldMap().entrySet()) {
            insert.addColumn(entry.getKey(), "#{" + entry.getValue() + "}");
        }

        return insert.toStatementString();
    }

    /**
     * Builds a MyBatis {@code <script>} INSERT that only includes columns whose field is non-null.
     *
     * <p>Both the column list and the value list are wrapped in a {@code <trim>} that drops the
     * trailing comma. Each entry is guarded by {@code <if test="field != null">}.
     */
    @Override
    protected String createSelectiveColumnInsertSQL() {
        Map<String, String> columnToField = tableMapMetaData.getColumnToFieldMap();

        StringBuilder columns = new StringBuilder();
        StringBuilder values = new StringBuilder();
        if (!columnToField.isEmpty()) {
            columns.append(TRIM_OPEN);
            values.append(TRIM_OPEN);
            for (Map.Entry<String, String> entry : columnToField.entrySet()) {
                String column = entry.getKey();
                String field = entry.getValue();
                columns.append("<if test=\"").append(field).append(" != null\">\n")
                        .append("   ").append(column).append(",")
                        .append("</if>");
                values.append("<if test=\"").append(field).append(" != null\">\n")
                        .append("   #{").append(field).append("},")
                        .append("</if>");
            }
            columns.append("</trim>");
            values.append("</trim>");
        }

        Insert insert = new Insert(dialect);
        insert.setTableName(tableMapMetaData.getTableName());
        // The whole dynamic column/value list is registered as a single pseudo-column, so Hibernate
        // renders "(<trim>columns</trim>) values (<trim>values</trim>)".
        insert.addColumn(columns.toString(), values.toString());
        return scriptWrap(insert.toStatementString());
    }

    /**
     * Builds an UPDATE that sets every non-identity column to {@code #{field}} and restricts
     * the update by the identity columns (also bound to {@code #{field}}).
     *
     * <p>NOTE(review): Hibernate omits the WHERE clause when no identity column is mapped, which
     * would update every row. Confirm that callers never use this for id-less entities.
     */
    @Override
    protected String createAllColumnUpdateByIdSQL() {
        Update update = new Update(dialect);
        update.setTableName(tableMapMetaData.getTableName());

        for (Map.Entry<String, String> entry : tableMapMetaData.getColumnToFieldMap().entrySet()) {
            if (tableMapMetaData.getIdentityColumnNames().contains(entry.getKey())) {
                update.addPrimaryKeyColumn(entry.getKey(), "#{" + entry.getValue() + "}");
            } else {
                update.addColumn(entry.getKey(), "#{" + entry.getValue() + "}");
            }
        }
        return update.toStatementString();
    }

    /**
     * Builds a MyBatis {@code <script>} UPDATE that only sets non-identity columns whose field is
     * non-null. The assignments sit inside a {@code <set>} element. The WHERE clause is the
     * {@code column=#{field}} conditions for the identity columns, joined with {@code and}.
     *
     * <p>The statement is assembled directly rather than through Hibernate's {@link Update}.
     * {@code Update} always renders {@code set col=value}, and stripping that text afterwards
     * caused two problems: a regex removal of "set" also hit table names such as
     * {@code user_settings}, and splitting on "where" broke on column names containing it.
     */
    @Override
    protected String createSelectiveColumnUpdateByIdSQL() {
        Map<String, String> columnToField = tableMapMetaData.getColumnToFieldMap();

        StringBuilder setFragment = new StringBuilder();
        StringBuilder whereConditions = new StringBuilder();
        if (!columnToField.isEmpty()) {
            setFragment.append("<set>");
            for (Map.Entry<String, String> entry : columnToField.entrySet()) {
                String column = entry.getKey();
                String field = entry.getValue();
                if (tableMapMetaData.getIdentityColumnNames().contains(column)) {
                    if (whereConditions.length() > 0) {
                        whereConditions.append(" and ");
                    }
                    whereConditions.append(column).append("=#{").append(field).append("}");
                } else {
                    setFragment.append("<if test=\"").append(field).append(" != null\">\n")
                            .append("   ").append(column).append(" = #{").append(field).append("},")
                            .append("</if>");
                }
            }
            setFragment.append("</set>");
        }

        // " where " is always emitted, as before. Without identity columns the statement ends in
        // a dangling WHERE and fails when executed, instead of silently updating every row.
        String sqlStatement = "update " + tableMapMetaData.getTableName() + " " + setFragment
                + " where " + whereConditions;
        return scriptWrap(sqlStatement);
    }

    /**
     * Builds a physical DELETE restricted by every identity column, each bound to
     * {@code #{primaryKey}}.
     *
     * <p>NOTE(review): with no identity columns Hibernate renders {@code delete from <table>}
     * without a WHERE clause, i.e. a full-table delete. Confirm that this path is unreachable for
     * id-less entities.
     */
    @Override
    protected String createDeleteByIdAtPhysicalSQL() {
        Delete delete = new Delete();
        delete.setTableName(tableMapMetaData.getTableName());
        for (Map.Entry<String, String> entry : tableMapMetaData.getIdentityColumnToFieldMap().entrySet()) {
            delete.addPrimaryKeyColumn(entry.getKey(), "#{primaryKey}");
        }
        return delete.toStatementString();
    }

    /**
     * Builds a logical-delete UPDATE that sets the {@code @LogicDelete} column to its "deleted"
     * value. The update is restricted by the identity columns, each bound to
     * {@code #{primaryKey}}.
     *
     * <p>If the entity has no {@code @LogicDelete} field, the deleted value falls back to
     * {@code WorkXMybatisEnvironment.getLogicDeleted()} and the field name is empty. No column
     * then matches, so the SET list is empty. NOTE(review): confirm the intended column for that
     * case.
     *
     * <p>NOTE(review): the deleted value is inlined verbatim as SQL. A non-numeric value would
     * need quoting; verify what {@code LogicDelete#deleted()} is expected to contain.
     *
     * @throws IllegalStateException if the entity class has more than one {@code @LogicDelete} field
     */
    @Override
    protected String createDeleteByIdAtLogicSQL() {
        String deleted;
        String deletedFieldName;
        List<Field> logicDeleteFields = FieldUtils.getFieldsListWithAnnotation(super.tableMapMetaData.getEntityClass(), LogicDelete.class);
        if (CollectionUtils.isEmpty(logicDeleteFields)) {
            deleted = WorkXMybatisEnvironment.getLogicDeleted();
            deletedFieldName = "";
        } else if (logicDeleteFields.size() > 1) {
            // getSimpleName(): concatenating the Class itself printed "@interface fully.qualified.Name".
            throw new IllegalStateException("There is more than one @" + LogicDelete.class.getSimpleName()
                    + " annotation in the " + super.tableMapMetaData.getEntityClass());
        } else {
            Field deletedField = logicDeleteFields.get(0);
            LogicDelete logicDelete = deletedField.getAnnotation(LogicDelete.class);
            deleted = logicDelete.deleted();
            deletedFieldName = deletedField.getName();
        }

        Update update = new Update(dialect);
        update.setTableName(tableMapMetaData.getTableName());

        for (Map.Entry<String, String> entry : tableMapMetaData.getColumnToFieldMap().entrySet()) {
            if (tableMapMetaData.getIdentityColumnNames().contains(entry.getKey())) {
                update.addPrimaryKeyColumn(entry.getKey(), "#{primaryKey}");
            } else if (StringUtils.equals(deletedFieldName, entry.getValue())) {
                update.addColumn(entry.getKey(), deleted);
            }
        }
        return update.toStatementString();
    }

    /**
     * Builds a SELECT of every mapped column, aliased to its field name, restricted by the
     * identity columns. Each identity column is compared with {@code #{primaryKey}} and the
     * conditions are joined with {@code and}.
     *
     * <p>The conditions are combined instead of calling {@code setWhereClause} once per column,
     * which kept only the last identity column.
     */
    @Override
    protected String createSelectByIdSQL() {
        Select select = newAllColumnSelect();

        StringBuilder where = new StringBuilder();
        for (String identityColumn : tableMapMetaData.getIdentityColumnToFieldMap().keySet()) {
            if (where.length() > 0) {
                where.append(" and ");
            }
            where.append(identityColumn).append(" = ").append("#{primaryKey}");
        }
        if (where.length() > 0) {
            select.setWhereClause(where.toString());
        }
        return select.toStatementString();
    }

    /**
     * Builds a MyBatis {@code <script>} SELECT of every mapped column whose WHERE clause is a
     * {@code <where>} element. Each field contributes {@code AND column = #{field}} only when it
     * is non-null.
     *
     * <p>The {@code <where>} fragment is appended after Hibernate renders the statement. Passing
     * it to {@code setWhereClause} made Hibernate emit a literal "where" that then had to be
     * removed by regex, which also corrupted any table, column or alias containing "where".
     */
    @Override
    protected String createSelectByDynamicWhereSQL() {
        Map<String, String> columnToField = tableMapMetaData.getColumnToFieldMap();

        StringBuilder whereFragment = new StringBuilder(" ");
        if (!columnToField.isEmpty()) {
            whereFragment.append("<where> ");
            for (Map.Entry<String, String> entry : columnToField.entrySet()) {
                String column = entry.getKey();
                String field = entry.getValue();
                whereFragment.append(" <if test=\"").append(field).append(" != null\">\n")
                        .append(" AND ").append(column).append(" = ").append("#{").append(field).append("} ")
                        .append("</if> ");
            }
            whereFragment.append(" </where>");
        }

        return scriptWrap(newAllColumnSelect().toStatementString() + whereFragment);
    }

    /**
     * Creates a {@link Select} over the mapped table whose select clause lists every mapped
     * column, qualified by the table name and aliased to its entity field name.
     */
    private Select newAllColumnSelect() {
        String tableName = tableMapMetaData.getTableName();
        SelectFragment selectFragment = new SelectFragment();
        for (Map.Entry<String, String> entry : tableMapMetaData.getColumnToFieldMap().entrySet()) {
            selectFragment.addColumn(tableName, entry.getKey(), entry.getValue());
        }

        Select select = new Select(dialect);
        select.setSelectClause(selectFragment)
                .setFromClause(tableName);
        return select;
    }

    /** Wraps a statement in a MyBatis {@code <script>} element so its dynamic-SQL tags are parsed. */
    private String scriptWrap(String sqlStatement) {
        return "<script>" + sqlStatement + "</script>";
    }
}
